# -*- coding: utf-8 -*-
"""
Created on Fri Oct 18 10:19:02 2024.

@author: James Maldaner
@email: maldaner@ualberta.ca
"""
# Import statements
from dataclasses import dataclass
import unittest
import json

import numpy as np
import scipy.constants as const

from scipy.optimize import fsolve



# Physical Constants (SI units, CODATA values provided by scipy.constants)
pi: float = np.pi  # [] Circle constant

c: float = const.speed_of_light  # [m/s] Speed of light in vacuum
h: float = const.Planck  # [J/Hz] Planck constant
e_charge: float = const.e  # [C] Elementary charge
k_B: float = const.Boltzmann  # [J/K] Boltzmann constant
hbar: float = const.hbar  # [Js] Reduced Planck constant

# Short aliases for the NumPy element-wise functions used in the formulas
exp = np.exp
sqrt = np.sqrt

# Load configuration files
# Each file is a JSON document read from ``config_folder``; the folder path is
# relative to the current working directory.

# Data file for the photodetector
pd_filename = 'thorlabs_config.json'

# Data file for the laser
laser_filename = 'laser_config.json'

# Data file for the xenon
xenon_filename = 'xenon_config.json'

# Data file for other experimental details
exp_filename = 'long_pulse_config.json'

config_folder = 'config/'


def _load_config(filename: str) -> dict:
    """
    Read and parse one JSON configuration file from ``config_folder``.

    JSON text is UTF-8 by specification, so the encoding is given explicitly
    instead of relying on the platform default (e.g. cp1252 on Windows),
    which would mis-decode any non-ASCII characters in the file.

    Parameters
    ----------
    filename : str
        Name of the file inside ``config_folder``.

    Returns
    -------
    dict
        The parsed JSON object.
    """
    with open(config_folder + filename, encoding='utf-8') as config_file:
        return json.load(config_file)


pd_data = _load_config(pd_filename)
laser_data = _load_config(laser_filename)
xe_data = _load_config(xenon_filename)
exp_data = _load_config(exp_filename)



# Math Functions
# Conversions
def pressure_bar_to_pascal(p_bar: float) -> float:
    """
    Express a pressure given in bar in pascal.

    One bar is defined as exactly 100 000 Pa.

    Parameters
    ----------
    p_bar : float
        Pressure in bar.

    Returns
    -------
    float
        The same pressure in pascal.

    """
    pascal_per_bar = 1e5
    return p_bar * pascal_per_bar


def pressure_bar_to_millibar(p_bar: float) -> float:
    """
    Express a pressure given in bar in millibar.

    Parameters
    ----------
    p_bar : float
        Pressure in bar.

    Returns
    -------
    float
        The same pressure in millibar (1 bar = 1000 mbar).

    """
    millibar_per_bar = 1e3
    return p_bar * millibar_per_bar



def wavelength_to_omega(wavelength: float) -> float:
    """
    Angular frequency of light with the given vacuum wavelength.

    Uses omega = 2 * pi * c / lambda.

    Parameters
    ----------
    wavelength : float
        Wavelength in meters.

    Returns
    -------
    float
        Angular frequency in rad/s.

    """
    two_pi_c = 2 * pi * c
    return two_pi_c / wavelength



def RIN_dB_to_ratio(RIN: float) -> float:
    """
    Turn a relative intensity noise level in dBc/Hz into a linear ratio.

    Parameters
    ----------
    RIN : float
        RIN in dBc/Hz.

    Returns
    -------
    float
        RIN as a linear power ratio per hertz, in 1/Hz.

    """
    exponent = RIN / 10
    return 10 ** exponent



def RIN_ratio_to_dB(RIN: float) -> float:
    """
    Turn a linear relative intensity noise ratio into dBc/Hz.

    Inverse of ``RIN_dB_to_ratio``.

    Parameters
    ----------
    RIN : float
        RIN as a linear power ratio per hertz, in 1/Hz.

    Returns
    -------
    float
        RIN in dBc/Hz.

    """
    decades = np.log10(RIN)
    return 10 * decades


@dataclass
class calculation:
    """
    Perform the calculations associated with a magnetometry experiment.

    Field defaults come from the JSON configuration files loaded at module
    import (``exp_data``, ``xe_data``, ``laser_data`` and ``pd_data``).
    ``__post_init__`` runs ``update_calculations``, which stores every derived
    quantity (densities, rates, noise variances and sensitivities) as an
    instance attribute. Call ``update_calculations`` again after changing an
    input field to refresh them.
    """

    Ppol: float = exp_data['Ppol']  # [] Initial degree of xenon spin polarization
    eta_isotope: float = exp_data['eta_isotope']  # [] Isotopic fraction of 129Xe
    pres_Xe_bar: float = exp_data['pres_Xe_bar']  # [Bar] xenon pressure in units of bar
    T_Xe: float = exp_data['T_Xe']  # [K] Xenon temperature
    T2_Xe: float = exp_data['T2_Xe']  # [sec] Xenon spin T2 decay time
    tau_pulse_BL: float = exp_data['tau_pulse_BL']  # [s] bandwidth limited laser pulse duration
    w0: float = exp_data['w0']  # [m] laser beam radius; 0 requests optimization (see update_calculations)
    geom_efficiency: float = exp_data['geom_efficiency']  # [] Geometric detection efficiency
    T_m: float = exp_data['T_m']  # [sec] Measurement time
    N_meas: float = exp_data['N_meas']  # [] Number of measurements

    alpha_TP: float = xe_data['alpha_TP']  # [m^4/J^2] Two-photon absorption coefficient
    gamma_Xe_nat: float = xe_data['gamma_Xe_nat']  # [Hz] Natural width of the two-photon transition
    gamma_Xe_pres_per_mbar: float = xe_data['gamma_Xe_pres_per_mbar']  # [Hz/mbar] Pressure broadening
    gm_ratio_Xe: float = xe_data['gm_ratio_Xe']  # [rad * Hz/T] Gyromagnetic ratio


    P_laser: float = laser_data['P_laser']  # [W] UV laser power
    RIN: float = laser_data['RIN']  # [dBc/Hz] Laser relative intensity noise

    Resp: float = pd_data['Resp']  # [A/W] Max responsivity
    Resp_ratio = pd_data['Resp_ratio']
    # Relative response of the photodetectors between the two wavelengths.
    # NOTE(review): this has no type annotation, so it is a plain class
    # attribute rather than a dataclass field (it cannot be passed to the
    # constructor), and no calculation below reads it.
    NEP: float = pd_data['NEP']  # [W/vHz] Min NEP
    R_TIA: float = pd_data['R_TIA']  # [V/A] amplification (feedback resistance) of the TIA
    RCMRR: float = pd_data['RCMRR']  # [dB] Common-mode rejection of the balanced detector

    magnetic_field: float = 0.15e-15  # [T] Magnetic field strength

    pd_data_file: str = pd_data['filename']
    laser_data_file: str = laser_data['filename']
    xe_data_file: str = xe_data['filename']
    exp_data_file: str = exp_data['filename']

    def __post_init__(self):
        """After dataclass initialization, run the calculations."""
        self.update_calculations()


    def update_calculations(self):
        """
        Recompute all derived physical quantities and noise contributions.

        Execution order (each step relies on attributes set by earlier ones):
        1. Gas and linewidth parameters
        2. Atomic population and interaction parameters
        3. Optical transition rates
        4. Signal amplitude
        5. Noise contributions
        6. Total sensitivity

        If ``w0`` is 0 the beam waist is replaced by the output of
        ``opt_w0``. The optimized value overwrites ``self.w0``, so later calls
        keep it unless ``w0`` is reset to 0 first.
        """
        self.pres_Xe = pressure_bar_to_pascal(self.pres_Xe_bar)
        # [Pascal] xenon pressure in units of Pascal=N/m^2
        self.pres_Xe_millibar = pressure_bar_to_millibar(self.pres_Xe_bar)
        # [millibar] xenon pressure in units of millibar
        self.gamma_Xe_pres = self.calc_spect_width_pressure_broadened()
        # [Hz] Pressure broadened spectral width
        self.gamma = self.gamma_Xe_nat + self.gamma_Xe_pres
        # [Hz] Total spectral width

        self.density_Xe = self.calc_density_Xe()
        # [1/m^3] Density of the xenon 129 atoms
        self.length = self.calc_interaction_length()
        # [m] Interaction length
        self.g_0 = self.calc_g_0()
        # [1/Hz] Lineshape parameter = 4/gamma


        self.det_eff = self.det_efficiency_calc()
        # [] Detector efficiency


        if self.w0 == 0:
            # If w0 is not specified, find the optimal w0
            self.w0 = self.opt_w0()
            # [m] Optimized beam waist

        self.N_Xe = self.calc_num_Xe_atoms()
        # [] Number of xenon 129 atoms involved in the interaction
        self.intensity_laser = self.calc_laser_intensity()
        # [W/m^2] laser intensity
        self.W_2 = self.calc_two_photon_rate_per_atom()
        # [Hz] Two photon interaction rate per atom
        self.n_bar_f = self.calc_IR_photon_mean_rate()
        # [Hz] Mean IR fluorescence photon rate (before detection efficiency)


        # Projection Noise
        self.sigma_omega_B_PN_squared = self.calc_sigma_omega_B_PN_squared()
        self.sensitivity_PN = (
            self.sensitivity_calc(self.sigma_omega_B_PN_squared))

        # Shot noise
        self.sigma_omega_B_SN_IR_squared = (
            self.calc_sigma_omega_B_SN_IR_squared())
        self.sensitivity_SN_IR = (
            self.sensitivity_calc(self.sigma_omega_B_SN_IR_squared))

        self.sigma_omega_B_SN_LPUV_squared = (
            self.calc_sigma_omega_B_SN_LPUV_squared())
        self.sensitivity_SN_LPUV = (
            self.sensitivity_calc(self.sigma_omega_B_SN_LPUV_squared))


        self.sigma_omega_B_SN_squared = (self.sigma_omega_B_SN_IR_squared +
                                         self.sigma_omega_B_SN_LPUV_squared)

        self.sensitivity_SN = (
            self.sensitivity_calc(self.sigma_omega_B_SN_squared))

        # Fundamental Noise
        self.sigma_omega_B_FN_squared = (self.sigma_omega_B_SN_squared
                                         + self.sigma_omega_B_PN_squared)
        self.sensitivity_FN = self.sensitivity_calc(
            self.sigma_omega_B_FN_squared)
        # Detector Noise
        self.sigma_omega_B_DN_NEP_squared = (
            self.calc_sigma_omega_B_DN_NEP_squared())
        self.sensitivity_DN_NEP = (
            self.sensitivity_calc(self.sigma_omega_B_DN_NEP_squared))

        self.sigma_omega_B_DN_TN_squared = (
            self.calc_sigma_omega_B_DN_TN_squared())
        self.sensitivity_DN_TN = (
            self.sensitivity_calc(self.sigma_omega_B_DN_TN_squared))

        self.sigma_omega_B_DN_squared = (self.sigma_omega_B_DN_NEP_squared
                                         + self.sigma_omega_B_DN_TN_squared)

        self.sensitivity_DN = (
            self.sensitivity_calc(self.sigma_omega_B_DN_squared))

        # Intensity Noise
        self.sigma_omega_B_IN_squared = self.calc_sigma_omega_B_RIN_squared()
        self.sensitivity_IN = (
            self.sensitivity_calc(self.sigma_omega_B_IN_squared))

        # Technical Noise
        self.sigma_omega_B_TN_squared = (self.sigma_omega_B_IN_squared +
                                         self.sigma_omega_B_DN_squared)

        self.sensitivity_TN = self.sensitivity_calc(self.sigma_omega_B_TN_squared)

        # Total: projection + shot + detector + intensity noise
        self.sigma_omega_B_squared = (self.sigma_omega_B_PN_squared +
                                      self.sigma_omega_B_SN_squared +
                                      self.sigma_omega_B_DN_squared +
                                      self.sigma_omega_B_IN_squared)
        self.sensitivity = self.sensitivity_calc(self.sigma_omega_B_squared)


    def calc_spect_width_pressure_broadened(self) -> float:
        """
        Find the pressure-broadened spectral width in Hz.

        Computed as ``gamma_Xe_pres_per_mbar * pres_Xe_millibar``; requires
        ``pres_Xe_millibar`` to be set.

        Returns
        -------
        float
            Spectral width in Hz.

        """
        p_millibar = self.pres_Xe_millibar
        gamma_Xe_pressure_per_mbar = self.gamma_Xe_pres_per_mbar
        return gamma_Xe_pressure_per_mbar * p_millibar


    def calc_num_Xe_atoms(self) -> float:
        """
        Calculate the number of xenon 129 atoms in the interaction volume.

        The volume is a cylinder of length ``length`` and radius ``w0``:
        N = length * pi * w0^2 * density_Xe.

        Returns
        -------
        float
            Number of xenon 129 atoms involved in the interaction.

        """
        length = self.length
        w0 = self.w0
        density = self.density_Xe
        return length * pi * w0 ** 2 * density


    def calc_density_Xe(self) -> float:
        """
        Find the density of xenon 129 atoms in the glass cell.

        Ideal-gas density scaled by the isotopic fraction:
        eta_isotope * p / (k_B * T), with p = ``pres_Xe`` in pascal and
        T = ``T_Xe`` in kelvin.

        Returns
        -------
        float
            Density of xenon 129 atoms in 1/m^3.

        """
        eta_isotope = self.eta_isotope
        pressure_Xe = self.pres_Xe
        T_Xe = self.T_Xe
        return eta_isotope * pressure_Xe / (k_B * T_Xe)


    def calc_interaction_length(self) -> float:
        """
        Calculate the interaction length for the Doppler-free case.

        Equal to the spatial length c * tau of the bandwidth-limited pulse
        ``tau_pulse_BL``.

        Returns
        -------
        float
            Pulse length in meters.

        """
        tau_pulse = self.tau_pulse_BL
        return c * tau_pulse


    def calc_larmor_frequency(self) -> float:
        """
        Calculate the Larmor frequency gm_ratio_Xe * magnetic_field.

        Returns
        -------
        float
            Larmor frequency in rad/s.

        """
        gamma = self.gm_ratio_Xe
        magnetic_field = self.magnetic_field
        return gamma * magnetic_field


    def calc_laser_intensity(self) -> float:
        """
        Calculate the average laser intensity P_laser / (pi * w0^2).

        Returns
        -------
        float
            Average laser intensity in W/m^2.

        """
        power = self.P_laser
        w0 = self.w0
        return power / (pi * w0 ** 2)


    def calc_two_photon_rate_per_atom(self) -> float:
        """
        Calculate the two-photon transition rate in the Doppler-free case.

        W_2 = 6 * alpha_TP * I^2 * g_0.

        Note: the coefficient differs from what is stated in the paper; the
        additional factor of 4 is included in the g_0 calculation.

        Returns
        -------
        float
            Two-photon transition rate per atom in Hz.

        """
        alpha_Xe = self.alpha_TP
        g_0 = self.g_0
        intensity = self.intensity_laser
        return 6 * alpha_Xe * intensity ** 2 * g_0


    def calc_g_0(self) -> float:
        """
        Doppler-free lineshape on resonance, g_0 = 4 / gamma.

        See reference
        T. D. Raymond, N. Böwering, C.-Y. Kuo, and J. W. Keto,
        Two-photon laser spectroscopy of xenon collision pairs,
        Phys. Rev. A 29, 721 (1984).

        Returns
        -------
        float
            Lineshape in 1/Hz.

        """
        gamma = self.gamma
        return 4 / gamma


    def calc_HP_UV_RIN(self) -> float:
        """
        Calculate the RIN equivalent to shot noise in the high power UV signal.

        Evaluates hbar * omega / P_laser at a wavelength of 256 nm and
        converts the ratio to dBc/Hz.

        Returns
        -------
        float
            RIN equivalent to shot noise in dBc/Hz.

        """
        omega_256 = wavelength_to_omega(256e-9)
        P_l = self.P_laser
        return RIN_ratio_to_dB((hbar * omega_256) / P_l)


    def calc_IR_photon_mean_rate(self) -> float:
        """
        Calculate the mean IR fluorescence photon rate n_bar_f in photon/s.

        Computed as 0.5 * N_Xe * W_2. The detection efficiency is NOT
        included here; ``sinusoid_A`` and the noise terms multiply by
        ``det_eff`` separately.

        Returns
        -------
        float
            Mean photon rate in Hz.

        """
        W_2 = self.W_2
        N_Xe = self.N_Xe
        return 0.5 * N_Xe * W_2


    def quantum_efficiency_calc(self) -> float:
        """
        Quantum efficiency of the photodetector at the 992 nm IR wavelength.

        Converts the responsivity R [A/W] into electrons per photon via
        QE = R * hbar * omega / e.

        Returns
        -------
        float
            Dimensionless quantum efficiency.

        """
        responsivity = self.Resp
        omega_IR = wavelength_to_omega(992e-9)
        return responsivity * hbar * omega_IR / e_charge


    def det_efficiency_calc(self) -> float:
        """
        Overall detection efficiency: geometric times quantum efficiency.

        Returns
        -------
        float
            Dimensionless detection efficiency.

        """
        geom_efficiency = self.geom_efficiency
        quantum_efficiency = self.quantum_efficiency_calc()
        return geom_efficiency * quantum_efficiency


    def sensitivity_calc(self, sigma_omega_B_squared: float) -> float:
        """
        Sensitivity from the variance in the Larmor frequency.

        Computed as sqrt(T_m / N_meas) * sqrt(sigma^2) / gm_ratio_Xe.

        Parameters
        ----------
        sigma_omega_B_squared : float
            Variance in the Larmor frequency in (rad/s)^2.

        Returns
        -------
        float
            Sensitivity in T/vHz.

        """
        T_m = self.T_m
        N_m = self.N_meas
        gamma = self.gm_ratio_Xe
        return np.sqrt(T_m / N_m) * np.sqrt(sigma_omega_B_squared) / gamma


    def calc_sigma_omega_B_PN_squared(self) -> float:
        """
        Calculate the Larmor-frequency variance due to projection noise.

        (2 pi)^2 / (T2_Xe * N_Xe * T_m).

        Returns
        -------
        float
            Variance in the Larmor frequency due to projection noise.

        """
        T_2 = self.T2_Xe
        N_Xe = self.N_Xe
        T_m = self.T_m
        return (2 * pi) ** 2 / (T_2 * N_Xe * T_m)


    def calc_sigma_omega_B_squared(self, n_A_squared: float) -> float:
        """
        Calculate the variance in detection due to noise in the current.

        6 * n_A^2 * T_2 * (exp(2 T_m / T_2) - 1) / (A^2 * T_m^4), where A is
        the signal amplitude returned by ``sinusoid_A``.

        Parameters
        ----------
        n_A_squared : float
            Current noise spectral density squared at omega_B in A^2/Hz.

        Returns
        -------
        float
            Variance in the Larmor frequency detection in (rad/s)^2.

        """
        T_2 = self.T2_Xe
        T_m = self.T_m
        A = self.sinusoid_A()
        return (6 * n_A_squared * T_2 * (exp(2*T_m/T_2)-1) /
                (A ** 2 * T_m ** 4))


    def sinusoid_A(self) -> float:
        """
        Calculate the initial amplitude of the non-decaying sinusoid.

        e * n_bar_f * Ppol * det_eff.

        Returns
        -------
        float
            Initial amplitude of the non-decaying sinusoidal component of the
            detected signal in amps.

        """
        n_bar_f = self.n_bar_f
        Ppol = self.Ppol
        eff = self.det_eff
        return e_charge * n_bar_f * Ppol * eff


    def calc_sigma_omega_B_SN_IR_squared(self) -> float:
        """
        Calculate the variance due to shot noise in the IR signal.

        Current noise density n_A^2 = 2 e^2 n_bar_f det_eff.

        Returns
        -------
        float
            Variance in the Larmor frequency due to shot noise in the IR
            signal in (rad/s)^2.

        """
        n_bar_f = self.n_bar_f
        eff = self.det_eff
        n_A_squared = 2 * e_charge ** 2 * n_bar_f * eff
        return self.calc_sigma_omega_B_squared(n_A_squared)


    def calc_sigma_omega_B_SN_LPUV_squared(self) -> float:
        """
        Calculate the variance due to shot noise in the low power UV signal.

        NOTE(review): the current noise density used here
        (2 e^2 n_bar_f det_eff) is identical to the IR shot-noise term, and
        the photodetector ``Resp_ratio`` between the two wavelengths is not
        applied. Confirm this is the intended model. ``opt_w0`` assumes the
        two shot-noise terms are equal.

        Returns
        -------
        float
            Variance in the Larmor frequency due to shot noise in the low
            power UV signal in (rad/s)^2.

        """
        n_bar_f = self.n_bar_f
        eff = self.det_eff
        n_A_squared = 2 * e_charge ** 2 * n_bar_f * eff
        return self.calc_sigma_omega_B_squared(n_A_squared)


    def calc_sigma_omega_B_DN_NEP_squared(self) -> float:
        """
        Variance in the Larmor frequency due to the NEP of the detector.

        Current noise density n_A^2 = 2 * (NEP * Resp)^2.

        Returns
        -------
        float
            Variance in the Larmor frequency due to NEP in the detector in
            (rad/s)^2.

        """
        NEP_min = self.NEP
        R_max = self.Resp
        n_A_squared = 2 * (NEP_min * R_max) ** 2
        return self.calc_sigma_omega_B_squared(n_A_squared)


    def calc_sigma_omega_B_DN_TN_squared(self) -> float:
        """
        Variance due to thermal (Johnson) noise in the TIA.

        Current noise density n_A^2 = 4 k_B T / R_TIA, with T taken from
        ``T_Xe`` (used here as the detector temperature).

        Returns
        -------
        float
            Variance due to the thermal noise in the balanced detector in
            (rad/s)^2.

        """
        Temp = self.T_Xe
        R_TIA = self.R_TIA
        n_A_squared = 4 * k_B * Temp / R_TIA
        return self.calc_sigma_omega_B_squared(n_A_squared)


    def calc_sigma_omega_B_RIN_squared(self) -> float:
        """
        Variance due to the laser RIN (technical noise).

        Takes into account the specified RIN and the common-mode rejection
        ``RCMRR`` of the balanced detector, both given in dB.

        Returns
        -------
        float
            Variance in the Larmor frequency detection due to RIN, corrected
            by the balanced detector, in (rad/s)^2.

        """
        RIN = self.RIN
        RCMRR = self.RCMRR
        n_bar_f = self.n_bar_f
        eff = self.det_eff
        R_RIN = RIN_dB_to_ratio(RIN)
        R_CMRR = RIN_dB_to_ratio(RCMRR)
        R_RIN_final = 5 * R_RIN / (R_CMRR)  # From combining equations 25 and 27
        n_A_squared = e_charge ** 2 * n_bar_f ** 2 * eff ** 2 * R_RIN_final
        return self.calc_sigma_omega_B_squared(n_A_squared)


    def opt_w0(self) -> float:
        """
        Find the beam waist that minimizes the total Larmor-frequency variance.

        With n_bar_f = K / w0^2, where
        K = 12 * length * density_Xe * alpha_TP * P_laser^2 / (pi * gamma),
        the variance terms computed in ``update_calculations`` scale as

            sigma^2(w0) = APN / w0^2 + ASN * w0^2 + ADN * w0^4 + const

        APN is projection noise. ASN is shot noise (IR + low-power UV).
        ADN is detector noise (NEP + TIA thermal noise). The RIN term does
        not depend on w0. Setting d(sigma^2)/d(w0) = 0 and multiplying by
        w0^3 gives

            -2 APN + 2 ASN w0^4 + 4 ADN w0^6 = 0,

        which is solved numerically.

        Reads T2_Xe, T_m, Ppol, alpha_TP, g_0, P_laser, density_Xe, length,
        det_eff, Resp, NEP, T_Xe and R_TIA, so those must already be set.

        Returns
        -------
        float
            Optimized beam waist in m, clamped to [100 um, 1 cm].

        """
        T_2 = self.T2_Xe
        T_m = self.T_m
        Ppol = self.Ppol
        alpha_Xe = self.alpha_TP
        gamma = 4 / self.g_0  # [Hz] total linewidth (g_0 = 4 / gamma)
        power = self.P_laser
        n_Xe = self.density_Xe
        length = self.length
        eff_flo = self.det_eff

        # Common factor T_2 (exp(2 T_m / T_2) - 1) / T_m^4 of
        # calc_sigma_omega_B_squared.
        decay = T_2 * (np.exp(2 * T_m / T_2) - 1) / T_m ** 4
        # n_bar_f * w0^2 (from N_Xe, laser intensity, W_2 and g_0).
        K = 12 * length * n_Xe * alpha_Xe * power ** 2 / (pi * gamma)

        APN = (2 * pi) ** 2 / (T_2 * T_m * n_Xe * pi * length)

        # IR and low-power UV shot noise contribute equally, each
        # 12 * decay * w0^2 / (K * Ppol^2 * eff_flo).
        ASN = 24 * decay / (Ppol ** 2 * eff_flo * K)

        # Detector current noise density: NEP term + TIA thermal term, as in
        # calc_sigma_omega_B_DN_NEP_squared / calc_sigma_omega_B_DN_TN_squared.
        # The signal amplitude A = e * n_bar_f * Ppol * det_eff, so eff_flo
        # enters squared in the denominator.
        n_A_squared_DN = (2 * (self.NEP * self.Resp) ** 2
                          + 4 * k_B * self.T_Xe / self.R_TIA)
        ADN = 6 * n_A_squared_DN * decay / (e_charge * Ppol * eff_flo * K) ** 2

        def dif_function(w_0):
            return -2 * APN + 2 * ASN * w_0 ** 4 + 4 * ADN * w_0 ** 6

        # fsolve only *warns* (it does not raise) when it fails to converge,
        # so check its status flag (ier == 1 means success) and retry from a
        # 1 mm starting point if the 1 cm guess did not converge.
        solution, _, ier, _ = fsolve(dif_function, 1e-2, full_output=True)
        if ier != 1:
            print('Got warning for beam waist optimizaiton')
            solution, _, ier, _ = fsolve(dif_function, 1e-3,
                                         full_output=True)

        # dif_function is even in w_0, so a negative root is equally valid.
        w0 = abs(solution[0])

        # Ensure the value of the beam waist is within the allowable range
        return min(max(w0, 100e-6), 1e-2)





def main():
    """
    Build a ``calculation`` from the loaded configuration files.

    Constructing the dataclass runs every derived-quantity computation via
    ``__post_init__``; nothing is printed or returned.
    """
    calculation()


if __name__ == '__main__':
    main()
